CS767_Pengfei_Ma_Project

Data preparation

I. Import data set

In [3]:
import numpy as np
import pandas as pd
In [4]:
heart1 = pd.read_csv("heart.csv")

heart2 = pd.read_csv('processed_cleveland.data', sep=",", 
                     names=["age", "sex", "cp", "trestbps", "chol", "fbs", "restecg", 
                            "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target"])

heart3 = pd.read_csv('reprocessed_hungarian.data', sep=" ", 
                     names=["age", "sex", "cp", "trestbps", "chol", "fbs", "restecg", 
                            "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target"]).abs()
In [5]:
heart1.head()
Out[5]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 63 1 3 145 233 1 0 150 0 2.3 0 0 1 1
1 37 1 2 130 250 0 1 187 0 3.5 0 0 2 1
2 41 0 1 130 204 0 0 172 0 1.4 2 0 2 1
3 56 1 1 120 236 0 1 178 0 0.8 2 0 2 1
4 57 0 0 120 354 0 1 163 1 0.6 2 0 2 1
In [6]:
heart2.head()
Out[6]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 63.0 1.0 1.0 145.0 233.0 1.0 2.0 150.0 0.0 2.3 3.0 0.0 6.0 0
1 67.0 1.0 4.0 160.0 286.0 0.0 2.0 108.0 1.0 1.5 2.0 3.0 3.0 2
2 67.0 1.0 4.0 120.0 229.0 0.0 2.0 129.0 1.0 2.6 2.0 2.0 7.0 1
3 37.0 1.0 3.0 130.0 250.0 0.0 0.0 187.0 0.0 3.5 3.0 0.0 3.0 0
4 41.0 0.0 2.0 130.0 204.0 0.0 2.0 172.0 0.0 1.4 1.0 0.0 3.0 0
In [7]:
heart3.head()
Out[7]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 40 1 2 140 289 0 0 172 0 0.0 9 9 9 0
1 49 0 3 160 180 0 0 156 0 1.0 2 9 9 1
2 37 1 2 130 283 0 1 98 0 0.0 9 9 9 0
3 48 0 4 138 214 0 0 108 1 1.5 2 9 9 3
4 54 1 3 150 9 0 0 122 0 0.0 9 9 9 0

II. Process the data set

In [8]:
# 1. merge all three data set
heart_total = pd.concat([heart1, heart2, heart3], ignore_index=True)
In [9]:
heart_total_array = heart_total.values # convert dataframe into numpy array
In [10]:
X = heart_total_array[:,0:12] # all other columns except target column
Y = heart_total_array[:,13] # target column
In [11]:
# 2. Scale the X set
from sklearn import preprocessing

min_max_scaler = preprocessing.MinMaxScaler()
X_scale = min_max_scaler.fit_transform(X)
In [12]:
# 3. split train and test set in 70/30.
from sklearn.model_selection import train_test_split

X_train, X_test, Y_train, Y_test = train_test_split(X_scale, Y, test_size=0.3)

Data visualization

In [14]:
import plotly.offline as py
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo
from plotly.subplots import make_subplots
import seaborn as sns
import matplotlib.pyplot as plt

I. Ages and sex

1. Histogram of patient age counts

In [13]:
female = heart_total.loc[heart_total['sex'] == 0]['age']
male = heart_total.loc[heart_total['sex'] == 1]['age']

fig = go.Figure()

fig.add_trace(go.Histogram(
    x=female,
    histnorm='percent',
    name='Female', 
    marker_color='#EB89B5',
    opacity=0.75
))
fig.add_trace(go.Histogram(
    x=male,
    histnorm='percent',
    name='Male',
    marker_color='#330C73',
    opacity=0.75
))

fig.update_layout(
    title_text='Patient age counts', 
    xaxis_title_text='Age', 
    yaxis_title_text='Count', 
    bargap=0.2, 
    bargroupgap=0.1
)

fig.show()

Histogram of patient age counts with heart attack risk comparsion

In [14]:
female_all = heart_total.loc[heart_total['sex'] == 0]
female = female_all.loc[female_all['target'] != 0]['age']
male_all = heart_total.loc[heart_total['sex'] == 1]
male = male_all.loc[male_all['target'] != 0]['age']

fig = make_subplots(rows=1, cols=2, subplot_titles=('Female','Male'))

fig.add_trace(
    go.Histogram(
    x=female_all['age'],
    histnorm='percent',
    name='All people', 
    marker_color='#EB89B5',
    opacity=0.75
), row=1, col=1)

fig.add_trace(
    go.Histogram(
    x=male_all['age'],
    histnorm='percent',
    name='All people',
    marker_color='#EB89B5',
    opacity=0.75
), row=1, col=2)

fig.add_trace(
    go.Histogram(
    x=female,
    histnorm='percent',
    name='Female with heart attack risk', 
    marker_color='#330C73',
    opacity=0.75
), row=1, col=1)

fig.add_trace(
    go.Histogram(
    x=male,
    histnorm='percent',
    name='Male with heart attack risk',
    marker_color='#330C73',
    opacity=0.75
), row=1, col=2)

fig.update_layout(
    title_text='Patient age counts with heart attack risk comparsion', 
    xaxis_title_text='Age', 
    yaxis_title_text='Count', 
    bargap=0.2, 
    bargroupgap=0.1
)

fig.show()

3. Boxplot of ages

In [15]:
trace0 = go.Box(y=heart_total['age'], name='Age for all people', marker_color = 'blue')
trace1 = go.Box(y=male_all['age'], name='Age for male', marker_color = 'red')
trace2 = go.Box(y=female_all['age'], name='Age for female', marker_color = 'green')

data = [trace0, trace1, trace2]
layout = go.Layout(title='Boxplots for ages and sex', 
                   xaxis_title="Age and sex",yaxis_title="Ages", hovermode='x')

fig = go.Figure(data=data, layout=layout)
fig.update_layout(legend_title_text='Ages and sex label')

fig.show()

II. Correlation Heatmap

In [16]:
fig=plt.figure(figsize=(12,8), dpi= 100, facecolor='w', edgecolor='k')
M0 = sns.heatmap(heart_total.corr(), annot = True).set_title('Heart Attack Features correlation')

plt.savefig("Heart Attack Features correlation heatmap.png")

Multi-layers Perceptron

I. Model Design

In [16]:
import tensorflow as tf
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from functools import partial
from keras import models
from keras import layers
from sklearn.model_selection import cross_val_score
In [18]:
def create_network(optimizer = 'rmsprop'):
    MaxNormDense = partial(keras.layers.Dense,
                           activation="selu", kernel_initializer="lecun_normal",
                           kernel_constraint=keras.constraints.max_norm(1.)
                           )

    model = models.Sequential()
    model.add(layers.Dense(200, activation='relu', input_shape=(12,)))
    model.add(layers.Dense(200, activation='relu'))
    model.add(layers.Dense(200, activation='relu'))
    model.add(layers.Dense(200, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))

    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

neural_network = KerasClassifier(build_fn=create_network,verbose=0)

1. Optimizers: : adam, nadam, sgd, rmsprop, adamax, adagrad

2. Epoch: 100, 200, 300

In [18]:
from sklearn.model_selection import GridSearchCV
In [21]:
epochs = [100, 200, 300]
optimizers = ['rmsprop', 'nadam', 'adam', 'sgd', 'adamax', 'adagrad']

# Create hyperparameter options
hyperparameters = dict(optimizer = optimizers, epochs=epochs)

# Create grid search
grid = GridSearchCV(estimator=neural_network, param_grid=hyperparameters, cv=3) 

# Fit gird search
grid_output = grid.fit(X_test, Y_test)
print(grid_output)
print(grid_output.best_params_)
print(grid_output.best_score_)
GridSearchCV(cv=3,
             estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x7ff64ea07b10>,
             param_grid={'epochs': [100, 200, 300],
                         'optimizer': ['rmsprop', 'nadam', 'adam', 'sgd',
                                       'adamax', 'adagrad']})
{'epochs': 100, 'optimizer': 'adam'}
0.524094432592392

III. Best model collection

In [22]:
model = Sequential([
    Dense(200, activation='relu', input_shape=(12,)),
    Dense(200, activation='relu'),
    Dense(200, activation='relu'),
    Dense(200, activation='relu'),
    Dense(1, activation='sigmoid')
    ])
In [23]:
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
In [24]:
hist = model.fit(X_train, Y_train, 
                 batch_size=32, 
                 epochs=100, 
                 validation_split=0.3)
Epoch 1/100
14/14 [==============================] - 0s 11ms/step - loss: 0.5004 - accuracy: 0.3016 - val_loss: 0.3534 - val_accuracy: 0.2698
Epoch 2/100
14/14 [==============================] - 0s 3ms/step - loss: 0.2009 - accuracy: 0.2948 - val_loss: -0.2852 - val_accuracy: 0.2646
Epoch 3/100
14/14 [==============================] - 0s 3ms/step - loss: -0.6441 - accuracy: 0.3946 - val_loss: -1.7723 - val_accuracy: 0.3280
Epoch 4/100
14/14 [==============================] - 0s 3ms/step - loss: -2.7377 - accuracy: 0.3696 - val_loss: -6.2115 - val_accuracy: 0.3968
Epoch 5/100
14/14 [==============================] - 0s 3ms/step - loss: -8.7932 - accuracy: 0.3968 - val_loss: -18.2126 - val_accuracy: 0.4021
Epoch 6/100
14/14 [==============================] - 0s 3ms/step - loss: -23.8661 - accuracy: 0.4195 - val_loss: -45.8736 - val_accuracy: 0.4550
Epoch 7/100
14/14 [==============================] - 0s 3ms/step - loss: -63.1841 - accuracy: 0.4263 - val_loss: -111.7402 - val_accuracy: 0.4074
Epoch 8/100
14/14 [==============================] - 0s 3ms/step - loss: -161.3972 - accuracy: 0.4240 - val_loss: -278.9972 - val_accuracy: 0.3545
Epoch 9/100
14/14 [==============================] - 0s 3ms/step - loss: -350.7342 - accuracy: 0.3787 - val_loss: -630.1064 - val_accuracy: 0.4603
Epoch 10/100
14/14 [==============================] - 0s 3ms/step - loss: -738.8523 - accuracy: 0.3923 - val_loss: -1392.9103 - val_accuracy: 0.4127
Epoch 11/100
14/14 [==============================] - 0s 3ms/step - loss: -1535.1074 - accuracy: 0.4127 - val_loss: -2525.4551 - val_accuracy: 0.3598
Epoch 12/100
14/14 [==============================] - 0s 3ms/step - loss: -2761.2400 - accuracy: 0.3923 - val_loss: -4546.9214 - val_accuracy: 0.4392
Epoch 13/100
14/14 [==============================] - 0s 3ms/step - loss: -5208.9331 - accuracy: 0.3991 - val_loss: -7443.7886 - val_accuracy: 0.3598
Epoch 14/100
14/14 [==============================] - 0s 3ms/step - loss: -8748.6172 - accuracy: 0.3832 - val_loss: -12611.0303 - val_accuracy: 0.4021
Epoch 15/100
14/14 [==============================] - 0s 3ms/step - loss: -14555.0576 - accuracy: 0.4127 - val_loss: -20338.1094 - val_accuracy: 0.3386
Epoch 16/100
14/14 [==============================] - 0s 3ms/step - loss: -22871.7754 - accuracy: 0.3946 - val_loss: -33436.7539 - val_accuracy: 0.4550
Epoch 17/100
14/14 [==============================] - 0s 3ms/step - loss: -36208.3594 - accuracy: 0.4376 - val_loss: -48099.2852 - val_accuracy: 0.3439
Epoch 18/100
14/14 [==============================] - 0s 3ms/step - loss: -52792.1055 - accuracy: 0.3832 - val_loss: -73538.5469 - val_accuracy: 0.4180
Epoch 19/100
14/14 [==============================] - 0s 3ms/step - loss: -78177.0234 - accuracy: 0.3878 - val_loss: -106225.3750 - val_accuracy: 0.4127
Epoch 20/100
14/14 [==============================] - 0s 3ms/step - loss: -112453.8203 - accuracy: 0.4331 - val_loss: -148034.3125 - val_accuracy: 0.3810
Epoch 21/100
14/14 [==============================] - 0s 3ms/step - loss: -158289.2344 - accuracy: 0.3810 - val_loss: -209423.8281 - val_accuracy: 0.4233
Epoch 22/100
14/14 [==============================] - 0s 3ms/step - loss: -213401.0312 - accuracy: 0.3968 - val_loss: -289870.9688 - val_accuracy: 0.3968
Epoch 23/100
14/14 [==============================] - 0s 3ms/step - loss: -284498.2500 - accuracy: 0.4354 - val_loss: -375572.6562 - val_accuracy: 0.3757
Epoch 24/100
14/14 [==============================] - 0s 3ms/step - loss: -378071.5312 - accuracy: 0.3651 - val_loss: -482645.1250 - val_accuracy: 0.3810
Epoch 25/100
14/14 [==============================] - 0s 3ms/step - loss: -480273.5312 - accuracy: 0.3991 - val_loss: -627297.2500 - val_accuracy: 0.4127
Epoch 26/100
14/14 [==============================] - 0s 3ms/step - loss: -621657.7500 - accuracy: 0.4308 - val_loss: -784949.6875 - val_accuracy: 0.4127
Epoch 27/100
14/14 [==============================] - 0s 3ms/step - loss: -776267.3750 - accuracy: 0.4059 - val_loss: -987932.6875 - val_accuracy: 0.3862
Epoch 28/100
14/14 [==============================] - 0s 3ms/step - loss: -969509.1250 - accuracy: 0.4376 - val_loss: -1227083.0000 - val_accuracy: 0.3915
Epoch 29/100
14/14 [==============================] - 0s 3ms/step - loss: -1196420.8750 - accuracy: 0.4195 - val_loss: -1501122.3750 - val_accuracy: 0.3915
Epoch 30/100
14/14 [==============================] - 0s 3ms/step - loss: -1460672.6250 - accuracy: 0.4150 - val_loss: -1825004.2500 - val_accuracy: 0.3968
Epoch 31/100
14/14 [==============================] - 0s 3ms/step - loss: -1759636.5000 - accuracy: 0.4263 - val_loss: -2199303.7500 - val_accuracy: 0.4021
Epoch 32/100
14/14 [==============================] - 0s 3ms/step - loss: -2110922.5000 - accuracy: 0.3855 - val_loss: -2639089.5000 - val_accuracy: 0.4021
Epoch 33/100
14/14 [==============================] - 0s 3ms/step - loss: -2536601.7500 - accuracy: 0.4444 - val_loss: -3079071.2500 - val_accuracy: 0.4074
Epoch 34/100
14/14 [==============================] - 0s 3ms/step - loss: -3009209.2500 - accuracy: 0.4172 - val_loss: -3650109.2500 - val_accuracy: 0.3915
Epoch 35/100
14/14 [==============================] - 0s 3ms/step - loss: -3578315.5000 - accuracy: 0.4104 - val_loss: -4359113.0000 - val_accuracy: 0.4180
Epoch 36/100
14/14 [==============================] - 0s 3ms/step - loss: -4199866.0000 - accuracy: 0.4354 - val_loss: -5094000.5000 - val_accuracy: 0.4180
Epoch 37/100
14/14 [==============================] - 0s 3ms/step - loss: -4958380.0000 - accuracy: 0.4036 - val_loss: -5943813.5000 - val_accuracy: 0.3862
Epoch 38/100
14/14 [==============================] - 0s 3ms/step - loss: -5773290.0000 - accuracy: 0.3968 - val_loss: -6905241.0000 - val_accuracy: 0.4074
Epoch 39/100
14/14 [==============================] - 0s 3ms/step - loss: -6682593.0000 - accuracy: 0.4535 - val_loss: -7963675.0000 - val_accuracy: 0.3968
Epoch 40/100
14/14 [==============================] - 0s 3ms/step - loss: -7694129.0000 - accuracy: 0.4059 - val_loss: -9096472.0000 - val_accuracy: 0.4127
Epoch 41/100
14/14 [==============================] - 0s 3ms/step - loss: -8852197.0000 - accuracy: 0.4036 - val_loss: -10572594.0000 - val_accuracy: 0.3862
Epoch 42/100
14/14 [==============================] - 0s 3ms/step - loss: -10082817.0000 - accuracy: 0.4127 - val_loss: -12058211.0000 - val_accuracy: 0.3915
Epoch 43/100
14/14 [==============================] - 0s 3ms/step - loss: -11587468.0000 - accuracy: 0.4036 - val_loss: -13907982.0000 - val_accuracy: 0.4233
Epoch 44/100
14/14 [==============================] - 0s 3ms/step - loss: -13370164.0000 - accuracy: 0.4172 - val_loss: -15904910.0000 - val_accuracy: 0.4233
Epoch 45/100
14/14 [==============================] - 0s 3ms/step - loss: -14980238.0000 - accuracy: 0.4082 - val_loss: -18130654.0000 - val_accuracy: 0.3915
Epoch 46/100
14/14 [==============================] - 0s 3ms/step - loss: -17282162.0000 - accuracy: 0.4331 - val_loss: -20470374.0000 - val_accuracy: 0.4021
Epoch 47/100
14/14 [==============================] - 0s 3ms/step - loss: -19602868.0000 - accuracy: 0.4308 - val_loss: -23142968.0000 - val_accuracy: 0.4233
Epoch 48/100
14/14 [==============================] - 0s 3ms/step - loss: -22195468.0000 - accuracy: 0.4104 - val_loss: -26109296.0000 - val_accuracy: 0.4127
Epoch 49/100
14/14 [==============================] - 0s 3ms/step - loss: -25103220.0000 - accuracy: 0.4104 - val_loss: -29512304.0000 - val_accuracy: 0.3915
Epoch 50/100
14/14 [==============================] - 0s 3ms/step - loss: -28095764.0000 - accuracy: 0.3991 - val_loss: -33153034.0000 - val_accuracy: 0.3862
Epoch 51/100
14/14 [==============================] - 0s 3ms/step - loss: -31315162.0000 - accuracy: 0.4172 - val_loss: -37002540.0000 - val_accuracy: 0.4127
Epoch 52/100
14/14 [==============================] - 0s 3ms/step - loss: -35151724.0000 - accuracy: 0.4104 - val_loss: -41161596.0000 - val_accuracy: 0.3862
Epoch 53/100
14/14 [==============================] - 0s 3ms/step - loss: -39067832.0000 - accuracy: 0.4172 - val_loss: -45277592.0000 - val_accuracy: 0.3915
Epoch 54/100
14/14 [==============================] - 0s 3ms/step - loss: -43031832.0000 - accuracy: 0.4014 - val_loss: -49975604.0000 - val_accuracy: 0.3915
Epoch 55/100
14/14 [==============================] - 0s 3ms/step - loss: -47425740.0000 - accuracy: 0.3900 - val_loss: -55582892.0000 - val_accuracy: 0.4286
Epoch 56/100
14/14 [==============================] - 0s 3ms/step - loss: -52575948.0000 - accuracy: 0.4150 - val_loss: -61220524.0000 - val_accuracy: 0.4021
Epoch 57/100
14/14 [==============================] - 0s 3ms/step - loss: -57230604.0000 - accuracy: 0.3991 - val_loss: -67296424.0000 - val_accuracy: 0.3915
Epoch 58/100
14/14 [==============================] - 0s 3ms/step - loss: -63446224.0000 - accuracy: 0.4172 - val_loss: -73749040.0000 - val_accuracy: 0.3915
Epoch 59/100
14/14 [==============================] - 0s 3ms/step - loss: -69305336.0000 - accuracy: 0.4172 - val_loss: -80342632.0000 - val_accuracy: 0.3862
Epoch 60/100
14/14 [==============================] - 0s 3ms/step - loss: -75694448.0000 - accuracy: 0.4104 - val_loss: -88085568.0000 - val_accuracy: 0.3915
Epoch 61/100
14/14 [==============================] - 0s 3ms/step - loss: -82921744.0000 - accuracy: 0.3946 - val_loss: -95954456.0000 - val_accuracy: 0.3915
Epoch 62/100
14/14 [==============================] - 0s 3ms/step - loss: -89285488.0000 - accuracy: 0.3991 - val_loss: -104033384.0000 - val_accuracy: 0.3915
Epoch 63/100
14/14 [==============================] - 0s 3ms/step - loss: -97795576.0000 - accuracy: 0.4195 - val_loss: -112279152.0000 - val_accuracy: 0.4180
Epoch 64/100
14/14 [==============================] - 0s 3ms/step - loss: -105688936.0000 - accuracy: 0.4218 - val_loss: -120982848.0000 - val_accuracy: 0.4127
Epoch 65/100
14/14 [==============================] - 0s 3ms/step - loss: -113827968.0000 - accuracy: 0.3991 - val_loss: -129857088.0000 - val_accuracy: 0.3915
Epoch 66/100
14/14 [==============================] - 0s 3ms/step - loss: -121722136.0000 - accuracy: 0.4014 - val_loss: -141028048.0000 - val_accuracy: 0.3862
Epoch 67/100
14/14 [==============================] - 0s 3ms/step - loss: -131945776.0000 - accuracy: 0.4422 - val_loss: -151065568.0000 - val_accuracy: 0.3915
Epoch 68/100
14/14 [==============================] - 0s 3ms/step - loss: -141650272.0000 - accuracy: 0.4263 - val_loss: -161670576.0000 - val_accuracy: 0.3862
Epoch 69/100
14/14 [==============================] - 0s 3ms/step - loss: -151653344.0000 - accuracy: 0.3991 - val_loss: -173336016.0000 - val_accuracy: 0.4286
Epoch 70/100
14/14 [==============================] - 0s 3ms/step - loss: -162006720.0000 - accuracy: 0.4104 - val_loss: -185840944.0000 - val_accuracy: 0.3862
Epoch 71/100
14/14 [==============================] - 0s 3ms/step - loss: -173773904.0000 - accuracy: 0.3900 - val_loss: -198360656.0000 - val_accuracy: 0.4233
Epoch 72/100
14/14 [==============================] - 0s 3ms/step - loss: -185576608.0000 - accuracy: 0.4036 - val_loss: -212213216.0000 - val_accuracy: 0.3862
Epoch 73/100
14/14 [==============================] - 0s 3ms/step - loss: -199030416.0000 - accuracy: 0.4218 - val_loss: -226336624.0000 - val_accuracy: 0.3862
Epoch 74/100
14/14 [==============================] - 0s 3ms/step - loss: -210557088.0000 - accuracy: 0.4036 - val_loss: -241496352.0000 - val_accuracy: 0.3915
Epoch 75/100
14/14 [==============================] - 0s 3ms/step - loss: -225111680.0000 - accuracy: 0.4195 - val_loss: -257087088.0000 - val_accuracy: 0.4074
Epoch 76/100
14/14 [==============================] - 0s 3ms/step - loss: -240824928.0000 - accuracy: 0.4104 - val_loss: -272596448.0000 - val_accuracy: 0.3862
Epoch 77/100
14/14 [==============================] - 0s 3ms/step - loss: -255509632.0000 - accuracy: 0.4150 - val_loss: -291192192.0000 - val_accuracy: 0.4127
Epoch 78/100
14/14 [==============================] - 0s 3ms/step - loss: -271826912.0000 - accuracy: 0.4082 - val_loss: -309015040.0000 - val_accuracy: 0.4339
Epoch 79/100
14/14 [==============================] - 0s 3ms/step - loss: -290153504.0000 - accuracy: 0.4059 - val_loss: -329000224.0000 - val_accuracy: 0.4286
Epoch 80/100
14/14 [==============================] - 0s 3ms/step - loss: -308233184.0000 - accuracy: 0.4172 - val_loss: -349024224.0000 - val_accuracy: 0.3915
Epoch 81/100
14/14 [==============================] - 0s 3ms/step - loss: -325059104.0000 - accuracy: 0.4014 - val_loss: -372598336.0000 - val_accuracy: 0.4286
Epoch 82/100
14/14 [==============================] - 0s 3ms/step - loss: -347076736.0000 - accuracy: 0.3900 - val_loss: -395104544.0000 - val_accuracy: 0.3862
Epoch 83/100
14/14 [==============================] - 0s 3ms/step - loss: -365764672.0000 - accuracy: 0.4331 - val_loss: -418080928.0000 - val_accuracy: 0.3862
Epoch 84/100
14/14 [==============================] - 0s 3ms/step - loss: -389790208.0000 - accuracy: 0.4172 - val_loss: -441453888.0000 - val_accuracy: 0.4233
Epoch 85/100
14/14 [==============================] - 0s 3ms/step - loss: -408730400.0000 - accuracy: 0.4399 - val_loss: -465607584.0000 - val_accuracy: 0.4127
Epoch 86/100
14/14 [==============================] - 0s 3ms/step - loss: -432196224.0000 - accuracy: 0.4150 - val_loss: -492981984.0000 - val_accuracy: 0.4286
Epoch 87/100
14/14 [==============================] - 0s 3ms/step - loss: -458526592.0000 - accuracy: 0.3991 - val_loss: -520137920.0000 - val_accuracy: 0.4074
Epoch 88/100
14/14 [==============================] - 0s 3ms/step - loss: -483550144.0000 - accuracy: 0.4263 - val_loss: -547108864.0000 - val_accuracy: 0.3862
Epoch 89/100
14/14 [==============================] - 0s 3ms/step - loss: -510597152.0000 - accuracy: 0.4127 - val_loss: -575453376.0000 - val_accuracy: 0.4127
Epoch 90/100
14/14 [==============================] - 0s 3ms/step - loss: -534837888.0000 - accuracy: 0.4263 - val_loss: -605937472.0000 - val_accuracy: 0.4233
Epoch 91/100
14/14 [==============================] - 0s 3ms/step - loss: -564826752.0000 - accuracy: 0.4399 - val_loss: -636682176.0000 - val_accuracy: 0.3862
Epoch 92/100
14/14 [==============================] - 0s 3ms/step - loss: -593484864.0000 - accuracy: 0.4014 - val_loss: -669433408.0000 - val_accuracy: 0.4233
Epoch 93/100
14/14 [==============================] - 0s 3ms/step - loss: -623868480.0000 - accuracy: 0.4399 - val_loss: -703274496.0000 - val_accuracy: 0.3862
Epoch 94/100
14/14 [==============================] - 0s 3ms/step - loss: -652200896.0000 - accuracy: 0.4195 - val_loss: -739927872.0000 - val_accuracy: 0.3915
Epoch 95/100
14/14 [==============================] - 0s 3ms/step - loss: -688888576.0000 - accuracy: 0.4150 - val_loss: -777078976.0000 - val_accuracy: 0.3862
Epoch 96/100
14/14 [==============================] - 0s 3ms/step - loss: -719186816.0000 - accuracy: 0.4308 - val_loss: -815443392.0000 - val_accuracy: 0.3968
Epoch 97/100
14/14 [==============================] - 0s 3ms/step - loss: -748649088.0000 - accuracy: 0.3991 - val_loss: -854169792.0000 - val_accuracy: 0.3915
Epoch 98/100
14/14 [==============================] - 0s 3ms/step - loss: -795517120.0000 - accuracy: 0.4082 - val_loss: -898567104.0000 - val_accuracy: 0.3862
Epoch 99/100
14/14 [==============================] - 0s 3ms/step - loss: -834017024.0000 - accuracy: 0.4240 - val_loss: -939794880.0000 - val_accuracy: 0.4127
Epoch 100/100
14/14 [==============================] - 0s 3ms/step - loss: -870798208.0000 - accuracy: 0.4059 - val_loss: -985035904.0000 - val_accuracy: 0.3862
In [25]:
fig = make_subplots(rows=1, cols=2, subplot_titles=('Loss','Accuracy'))

fig.add_trace(
    go.Scatter(x=np.array(range(101)), y=hist.history['loss'],
                    mode='lines',marker_color='red',name='Train loss'), row=1, col=1)

fig.add_trace(
    go.Scatter(x=np.array(range(101)), y=hist.history['val_loss'],
                    mode='lines',marker_color='blue',name='Val loss'), row=1, col=1)

fig.add_trace(
    go.Scatter(x=np.array(range(101)), y=hist.history['accuracy'],
                    mode='lines',marker_color='orange',name='Train accuracy'), row=1, col=2)

fig.add_trace(
    go.Scatter(x=np.array(range(101)), y=hist.history['val_accuracy'],
                    mode='lines',marker_color='purple',name='Val_accuracy'), row=1, col=2)

fig.update_layout(
    title_text='Loss and accuracy of the best model', 
    xaxis_title_text='Epoch', 
    bargap=0.2, 
    bargroupgap=0.1,
    hovermode='x'
)

fig.show()

IV. Model evaluation

In [26]:
model.evaluate(X_test, Y_test)
9/9 [==============================] - 0s 891us/step - loss: nan - accuracy: 0.3838      
Out[26]:
[nan, 0.38376384973526]

2. The best model with 10 neurons each layer

In [17]:
def create_network(optimizer = 'rmsprop'):
    MaxNormDense = partial(keras.layers.Dense,
                           activation="selu", kernel_initializer="lecun_normal",
                           kernel_constraint=keras.constraints.max_norm(1.)
                           )

    model = models.Sequential()
    model.add(layers.Dense(10, activation='relu', input_shape=(12,)))
    model.add(layers.Dense(10, activation='relu'))
    model.add(layers.Dense(10, activation='relu'))
    model.add(layers.Dense(10, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))

    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

neural_network = KerasClassifier(build_fn=create_network,verbose=0)

II. Grid Search for neuron = 10

1. Optimizers: : adam, nadam, sgd, rmsprop, adamax, adagrad

2. Epoch: 100, 200, 300

In [20]:
epochs = [100, 200, 300]
optimizers = ['rmsprop', 'nadam', 'adam', 'sgd', 'adamax', 'adagrad']

# Create hyperparameter options
hyperparameters = dict(optimizer = optimizers, epochs=epochs)

# Create grid search
grid = GridSearchCV(estimator=neural_network, param_grid=hyperparameters, cv=3) 

# Fit gird search
grid_output = grid.fit(X_test, Y_test)
print(grid_output)
print(grid_output.best_params_)
print(grid_output.best_score_)
GridSearchCV(cv=3,
             estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x7fd74ba42bd0>,
             param_grid={'epochs': [100, 200, 300],
                         'optimizer': ['rmsprop', 'nadam', 'adam', 'sgd',
                                       'adamax', 'adagrad']})
{'epochs': 100, 'optimizer': 'rmsprop'}
0.38148148854573566

Best model with 10 neuron each layer

In [21]:
model = Sequential([
    Dense(10, activation='relu', input_shape=(12,)),
    Dense(10, activation='relu'),
    Dense(10, activation='relu'),
    Dense(10, activation='relu'),
    Dense(1, activation='sigmoid')
    ])
In [22]:
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])
In [23]:
hist = model.fit(X_train, Y_train, 
                 batch_size=32, 
                 epochs=100, 
                 validation_split=0.3)
Epoch 1/100
14/14 [==============================] - 1s 10ms/step - loss: 0.6368 - accuracy: 0.2971 - val_loss: 0.6044 - val_accuracy: 0.2751
Epoch 2/100
14/14 [==============================] - 0s 2ms/step - loss: 0.5998 - accuracy: 0.2993 - val_loss: 0.5670 - val_accuracy: 0.2751
Epoch 3/100
14/14 [==============================] - 0s 2ms/step - loss: 0.5721 - accuracy: 0.2993 - val_loss: 0.5376 - val_accuracy: 0.2751
Epoch 4/100
14/14 [==============================] - 0s 3ms/step - loss: 0.5477 - accuracy: 0.2993 - val_loss: 0.5094 - val_accuracy: 0.2751
Epoch 5/100
14/14 [==============================] - 0s 3ms/step - loss: 0.5285 - accuracy: 0.2993 - val_loss: 0.4893 - val_accuracy: 0.2751
Epoch 6/100
14/14 [==============================] - 0s 3ms/step - loss: 0.5125 - accuracy: 0.2993 - val_loss: 0.4709 - val_accuracy: 0.2751
Epoch 7/100
14/14 [==============================] - 0s 3ms/step - loss: 0.4969 - accuracy: 0.2993 - val_loss: 0.4529 - val_accuracy: 0.2751
Epoch 8/100
14/14 [==============================] - 0s 3ms/step - loss: 0.4848 - accuracy: 0.2993 - val_loss: 0.4354 - val_accuracy: 0.2751
Epoch 9/100
14/14 [==============================] - 0s 3ms/step - loss: 0.4690 - accuracy: 0.2993 - val_loss: 0.4133 - val_accuracy: 0.2751
Epoch 10/100
14/14 [==============================] - 0s 3ms/step - loss: 0.4499 - accuracy: 0.2993 - val_loss: 0.3894 - val_accuracy: 0.2751
Epoch 11/100
14/14 [==============================] - 0s 3ms/step - loss: 0.4297 - accuracy: 0.2993 - val_loss: 0.3628 - val_accuracy: 0.2751
Epoch 12/100
14/14 [==============================] - 0s 3ms/step - loss: 0.4074 - accuracy: 0.2993 - val_loss: 0.3295 - val_accuracy: 0.2751
Epoch 13/100
14/14 [==============================] - 0s 3ms/step - loss: 0.3772 - accuracy: 0.2993 - val_loss: 0.2902 - val_accuracy: 0.2804
Epoch 14/100
14/14 [==============================] - 0s 3ms/step - loss: 0.3463 - accuracy: 0.3016 - val_loss: 0.2513 - val_accuracy: 0.2804
Epoch 15/100
14/14 [==============================] - 0s 3ms/step - loss: 0.3170 - accuracy: 0.3039 - val_loss: 0.2141 - val_accuracy: 0.2751
Epoch 16/100
14/14 [==============================] - 0s 3ms/step - loss: 0.2792 - accuracy: 0.3175 - val_loss: 0.1654 - val_accuracy: 0.2804
Epoch 17/100
14/14 [==============================] - 0s 3ms/step - loss: 0.2413 - accuracy: 0.3197 - val_loss: 0.1158 - val_accuracy: 0.2751
Epoch 18/100
14/14 [==============================] - 0s 3ms/step - loss: 0.1989 - accuracy: 0.3379 - val_loss: 0.0690 - val_accuracy: 0.2698
Epoch 19/100
14/14 [==============================] - 0s 3ms/step - loss: 0.1534 - accuracy: 0.3379 - val_loss: 0.0089 - val_accuracy: 0.3016
Epoch 20/100
14/14 [==============================] - 0s 3ms/step - loss: 0.1134 - accuracy: 0.3469 - val_loss: -0.0629 - val_accuracy: 0.2910
Epoch 21/100
14/14 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.3537 - val_loss: -0.1319 - val_accuracy: 0.3175
Epoch 22/100
14/14 [==============================] - 0s 3ms/step - loss: -0.0032 - accuracy: 0.3469 - val_loss: -0.2111 - val_accuracy: 0.3228
Epoch 23/100
14/14 [==============================] - 0s 2ms/step - loss: -0.0715 - accuracy: 0.3537 - val_loss: -0.2872 - val_accuracy: 0.3386
Epoch 24/100
14/14 [==============================] - 0s 3ms/step - loss: -0.1370 - accuracy: 0.3537 - val_loss: -0.3995 - val_accuracy: 0.3333
Epoch 25/100
14/14 [==============================] - 0s 2ms/step - loss: -0.2219 - accuracy: 0.3764 - val_loss: -0.5064 - val_accuracy: 0.3228
Epoch 26/100
14/14 [==============================] - 0s 3ms/step - loss: -0.3064 - accuracy: 0.3696 - val_loss: -0.5992 - val_accuracy: 0.3545
Epoch 27/100
14/14 [==============================] - 0s 3ms/step - loss: -0.3811 - accuracy: 0.3968 - val_loss: -0.7335 - val_accuracy: 0.3598
Epoch 28/100
14/14 [==============================] - 0s 3ms/step - loss: -0.4865 - accuracy: 0.3968 - val_loss: -0.8727 - val_accuracy: 0.3651
Epoch 29/100
14/14 [==============================] - 0s 3ms/step - loss: -0.5916 - accuracy: 0.4082 - val_loss: -1.0219 - val_accuracy: 0.3598
Epoch 30/100
14/14 [==============================] - 0s 3ms/step - loss: -0.7223 - accuracy: 0.4082 - val_loss: -1.1833 - val_accuracy: 0.4180
Epoch 31/100
14/14 [==============================] - 0s 3ms/step - loss: -0.8440 - accuracy: 0.4104 - val_loss: -1.3544 - val_accuracy: 0.4127
Epoch 32/100
14/14 [==============================] - 0s 3ms/step - loss: -0.9685 - accuracy: 0.4376 - val_loss: -1.5491 - val_accuracy: 0.3862
Epoch 33/100
14/14 [==============================] - 0s 3ms/step - loss: -1.1172 - accuracy: 0.4263 - val_loss: -1.7228 - val_accuracy: 0.4180
Epoch 34/100
14/14 [==============================] - 0s 3ms/step - loss: -1.2529 - accuracy: 0.4512 - val_loss: -1.9363 - val_accuracy: 0.4074
Epoch 35/100
14/14 [==============================] - 0s 3ms/step - loss: -1.4264 - accuracy: 0.4444 - val_loss: -2.1465 - val_accuracy: 0.3704
Epoch 36/100
14/14 [==============================] - 0s 4ms/step - loss: -1.5972 - accuracy: 0.4308 - val_loss: -2.4032 - val_accuracy: 0.4392
Epoch 37/100
14/14 [==============================] - 0s 3ms/step - loss: -1.7958 - accuracy: 0.4580 - val_loss: -2.6898 - val_accuracy: 0.4233
Epoch 38/100
14/14 [==============================] - 0s 3ms/step - loss: -1.9754 - accuracy: 0.4490 - val_loss: -2.9506 - val_accuracy: 0.4339
Epoch 39/100
14/14 [==============================] - 0s 2ms/step - loss: -2.1265 - accuracy: 0.4490 - val_loss: -3.1847 - val_accuracy: 0.4339
Epoch 40/100
14/14 [==============================] - 0s 2ms/step - loss: -2.3639 - accuracy: 0.4671 - val_loss: -3.4753 - val_accuracy: 0.4180
Epoch 41/100
14/14 [==============================] - 0s 2ms/step - loss: -2.5922 - accuracy: 0.4444 - val_loss: -3.8146 - val_accuracy: 0.4286
Epoch 42/100
14/14 [==============================] - 0s 2ms/step - loss: -2.8199 - accuracy: 0.4512 - val_loss: -4.1710 - val_accuracy: 0.4233
Epoch 43/100
14/14 [==============================] - 0s 2ms/step - loss: -3.0969 - accuracy: 0.4467 - val_loss: -4.5234 - val_accuracy: 0.4233
Epoch 44/100
14/14 [==============================] - 0s 2ms/step - loss: -3.3864 - accuracy: 0.4535 - val_loss: -4.9293 - val_accuracy: 0.4074
Epoch 45/100
14/14 [==============================] - 0s 2ms/step - loss: -3.6765 - accuracy: 0.4331 - val_loss: -5.3079 - val_accuracy: 0.4286
Epoch 46/100
14/14 [==============================] - 0s 2ms/step - loss: -3.9995 - accuracy: 0.4444 - val_loss: -5.6766 - val_accuracy: 0.4497
Epoch 47/100
14/14 [==============================] - 0s 2ms/step - loss: -4.2963 - accuracy: 0.4694 - val_loss: -6.2265 - val_accuracy: 0.4180
Epoch 48/100
14/14 [==============================] - 0s 3ms/step - loss: -4.6568 - accuracy: 0.4671 - val_loss: -6.7244 - val_accuracy: 0.4127
Epoch 49/100
14/14 [==============================] - 0s 2ms/step - loss: -5.0283 - accuracy: 0.4626 - val_loss: -7.2142 - val_accuracy: 0.4074
Epoch 50/100
14/14 [==============================] - 0s 2ms/step - loss: -5.4144 - accuracy: 0.4512 - val_loss: -7.7134 - val_accuracy: 0.4233
Epoch 51/100
14/14 [==============================] - 0s 2ms/step - loss: -5.8191 - accuracy: 0.4580 - val_loss: -8.3089 - val_accuracy: 0.4233
Epoch 52/100
14/14 [==============================] - 0s 2ms/step - loss: -6.2873 - accuracy: 0.4626 - val_loss: -8.9417 - val_accuracy: 0.4180
Epoch 53/100
14/14 [==============================] - 0s 2ms/step - loss: -6.7471 - accuracy: 0.4580 - val_loss: -9.6171 - val_accuracy: 0.4180
Epoch 54/100
14/14 [==============================] - 0s 2ms/step - loss: -7.3647 - accuracy: 0.4717 - val_loss: -10.3444 - val_accuracy: 0.4074
Epoch 55/100
14/14 [==============================] - 0s 2ms/step - loss: -7.8844 - accuracy: 0.4399 - val_loss: -11.0095 - val_accuracy: 0.4497
Epoch 56/100
14/14 [==============================] - 0s 2ms/step - loss: -8.3465 - accuracy: 0.4694 - val_loss: -11.8321 - val_accuracy: 0.4286
Epoch 57/100
14/14 [==============================] - 0s 2ms/step - loss: -9.0375 - accuracy: 0.4739 - val_loss: -12.7890 - val_accuracy: 0.4127
Epoch 58/100
14/14 [==============================] - 0s 2ms/step - loss: -9.7490 - accuracy: 0.4649 - val_loss: -13.8254 - val_accuracy: 0.4233
Epoch 59/100
14/14 [==============================] - 0s 2ms/step - loss: -10.4922 - accuracy: 0.4467 - val_loss: -14.6120 - val_accuracy: 0.4497
Epoch 60/100
14/14 [==============================] - 0s 3ms/step - loss: -11.2196 - accuracy: 0.4717 - val_loss: -15.7778 - val_accuracy: 0.4074
Epoch 61/100
14/14 [==============================] - 0s 2ms/step - loss: -12.0694 - accuracy: 0.4580 - val_loss: -16.7397 - val_accuracy: 0.4444
Epoch 62/100
14/14 [==============================] - 0s 2ms/step - loss: -12.7927 - accuracy: 0.4694 - val_loss: -17.9731 - val_accuracy: 0.4233
Epoch 63/100
14/14 [==============================] - 0s 2ms/step - loss: -13.7186 - accuracy: 0.4558 - val_loss: -19.0721 - val_accuracy: 0.4339
Epoch 64/100
14/14 [==============================] - 0s 2ms/step - loss: -14.6516 - accuracy: 0.4739 - val_loss: -20.3700 - val_accuracy: 0.4180
Epoch 65/100
14/14 [==============================] - 0s 2ms/step - loss: -15.7172 - accuracy: 0.4558 - val_loss: -21.7584 - val_accuracy: 0.4233
Epoch 66/100
14/14 [==============================] - 0s 2ms/step - loss: -16.8008 - accuracy: 0.4603 - val_loss: -23.1625 - val_accuracy: 0.4392
Epoch 67/100
14/14 [==============================] - 0s 2ms/step - loss: -17.9869 - accuracy: 0.4762 - val_loss: -24.7578 - val_accuracy: 0.4286
Epoch 68/100
14/14 [==============================] - 0s 2ms/step - loss: -19.1507 - accuracy: 0.4603 - val_loss: -26.2611 - val_accuracy: 0.4339
Epoch 69/100
14/14 [==============================] - 0s 2ms/step - loss: -20.3139 - accuracy: 0.4558 - val_loss: -27.5590 - val_accuracy: 0.4444
Epoch 70/100
14/14 [==============================] - 0s 2ms/step - loss: -21.5194 - accuracy: 0.4580 - val_loss: -29.2306 - val_accuracy: 0.4444
Epoch 71/100
14/14 [==============================] - 0s 2ms/step - loss: -22.9312 - accuracy: 0.4739 - val_loss: -31.1452 - val_accuracy: 0.4127
Epoch 72/100
14/14 [==============================] - 0s 2ms/step - loss: -24.2589 - accuracy: 0.4535 - val_loss: -32.8387 - val_accuracy: 0.4392
Epoch 73/100
14/14 [==============================] - 0s 2ms/step - loss: -25.8358 - accuracy: 0.4671 - val_loss: -34.9538 - val_accuracy: 0.4180
Epoch 74/100
14/14 [==============================] - 0s 2ms/step - loss: -27.3318 - accuracy: 0.4558 - val_loss: -37.0411 - val_accuracy: 0.4233
Epoch 75/100
14/14 [==============================] - 0s 2ms/step - loss: -28.9691 - accuracy: 0.4603 - val_loss: -39.1720 - val_accuracy: 0.4286
Epoch 76/100
14/14 [==============================] - 0s 2ms/step - loss: -30.8820 - accuracy: 0.4535 - val_loss: -41.2631 - val_accuracy: 0.4603
Epoch 77/100
14/14 [==============================] - 0s 2ms/step - loss: -32.7062 - accuracy: 0.4512 - val_loss: -43.9731 - val_accuracy: 0.4550
Epoch 78/100
14/14 [==============================] - 0s 2ms/step - loss: -34.8362 - accuracy: 0.4603 - val_loss: -46.9273 - val_accuracy: 0.4392
Epoch 79/100
14/14 [==============================] - 0s 2ms/step - loss: -36.8563 - accuracy: 0.4603 - val_loss: -49.5188 - val_accuracy: 0.4656
Epoch 80/100
14/14 [==============================] - 0s 2ms/step - loss: -39.0161 - accuracy: 0.4739 - val_loss: -52.4452 - val_accuracy: 0.4497
Epoch 81/100
14/14 [==============================] - 0s 2ms/step - loss: -41.2122 - accuracy: 0.4717 - val_loss: -55.4500 - val_accuracy: 0.4497
Epoch 82/100
14/14 [==============================] - 0s 2ms/step - loss: -43.1863 - accuracy: 0.4717 - val_loss: -58.0236 - val_accuracy: 0.4709
Epoch 83/100
14/14 [==============================] - 0s 2ms/step - loss: -45.5355 - accuracy: 0.4830 - val_loss: -60.8920 - val_accuracy: 0.4709
Epoch 84/100
14/14 [==============================] - 0s 2ms/step - loss: -47.7896 - accuracy: 0.4785 - val_loss: -64.0184 - val_accuracy: 0.4550
Epoch 85/100
14/14 [==============================] - 0s 2ms/step - loss: -50.3584 - accuracy: 0.4853 - val_loss: -67.2781 - val_accuracy: 0.4815
Epoch 86/100
14/14 [==============================] - 0s 2ms/step - loss: -53.0734 - accuracy: 0.4921 - val_loss: -71.0659 - val_accuracy: 0.4392
Epoch 87/100
14/14 [==============================] - 0s 2ms/step - loss: -55.4818 - accuracy: 0.4853 - val_loss: -74.2583 - val_accuracy: 0.4339
Epoch 88/100
14/14 [==============================] - 0s 2ms/step - loss: -58.0147 - accuracy: 0.4762 - val_loss: -77.0354 - val_accuracy: 0.4603
Epoch 89/100
14/14 [==============================] - 0s 2ms/step - loss: -60.8674 - accuracy: 0.4830 - val_loss: -80.2426 - val_accuracy: 0.4815
Epoch 90/100
14/14 [==============================] - 0s 2ms/step - loss: -63.8277 - accuracy: 0.4943 - val_loss: -85.2019 - val_accuracy: 0.4339
Epoch 91/100
14/14 [==============================] - 0s 2ms/step - loss: -66.8954 - accuracy: 0.4785 - val_loss: -89.1475 - val_accuracy: 0.4444
Epoch 92/100
14/14 [==============================] - 0s 2ms/step - loss: -70.2816 - accuracy: 0.4830 - val_loss: -93.2389 - val_accuracy: 0.4603
Epoch 93/100
14/14 [==============================] - 0s 2ms/step - loss: -73.6190 - accuracy: 0.4762 - val_loss: -98.1667 - val_accuracy: 0.4444
Epoch 94/100
14/14 [==============================] - 0s 2ms/step - loss: -77.4965 - accuracy: 0.4785 - val_loss: -102.4406 - val_accuracy: 0.4603
Epoch 95/100
14/14 [==============================] - 0s 2ms/step - loss: -81.0781 - accuracy: 0.4921 - val_loss: -108.0041 - val_accuracy: 0.4444
Epoch 96/100
14/14 [==============================] - 0s 3ms/step - loss: -84.9484 - accuracy: 0.4830 - val_loss: -112.5224 - val_accuracy: 0.4656
Epoch 97/100
14/14 [==============================] - 0s 3ms/step - loss: -88.6827 - accuracy: 0.4739 - val_loss: -117.7566 - val_accuracy: 0.4550
Epoch 98/100
14/14 [==============================] - 0s 2ms/step - loss: -93.1647 - accuracy: 0.4785 - val_loss: -123.3086 - val_accuracy: 0.4444
Epoch 99/100
14/14 [==============================] - 0s 2ms/step - loss: -97.2833 - accuracy: 0.4717 - val_loss: -127.4789 - val_accuracy: 0.4815
Epoch 100/100
14/14 [==============================] - 0s 2ms/step - loss: -100.9244 - accuracy: 0.4943 - val_loss: -134.0545 - val_accuracy: 0.4603
In [24]:
fig = make_subplots(rows=1, cols=2, subplot_titles=('Loss','Accuracy'))

fig.add_trace(
    go.Scatter(x=np.array(range(101)), y=hist.history['loss'],
                    mode='lines',marker_color='red',name='Train loss'), row=1, col=1)

fig.add_trace(
    go.Scatter(x=np.array(range(101)), y=hist.history['val_loss'],
                    mode='lines',marker_color='blue',name='Val loss'), row=1, col=1)

fig.add_trace(
    go.Scatter(x=np.array(range(101)), y=hist.history['accuracy'],
                    mode='lines',marker_color='orange',name='Train accuracy'), row=1, col=2)

fig.add_trace(
    go.Scatter(x=np.array(range(101)), y=hist.history['val_accuracy'],
                    mode='lines',marker_color='purple',name='Val_accuracy'), row=1, col=2)

fig.update_layout(
    title_text='Loss and accuracy of the best model', 
    xaxis_title_text='Epoch', 
    bargap=0.2, 
    bargroupgap=0.1,
    hovermode='x'
)

fig.show()

Model Evaluation

In [25]:
model.evaluate(X_test, Y_test)
9/9 [==============================] - 0s 730us/step - loss: -108.8652 - accuracy: 0.4111
Out[25]:
[-108.86524200439453, 0.41111111640930176]
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [ ]:
 

Classifiers

I. Logistic regression

In [1]:
from sklearn.linear_model import LogisticRegression
In [69]:
log_reg_classifier = LogisticRegression()
log_reg_classifier.fit(X_train,Y_train)

prediction = log_reg_classifier.predict(X_test)
correct_rate = np.mean(prediction == Y_test)
In [70]:
print("The accuracy of logistic regression is:",correct_rate)
The accuracy of logistic regression is: 0.44814814814814813

II. Gaussian Naïve Bayes

In [71]:
from sklearn.naive_bayes import GaussianNB
In [72]:
NB_classifier = GaussianNB().fit(X_train,Y_train)
prediction = NB_classifier.predict(X_test)
correct_rate = np.mean(prediction== Y_test)
In [73]:
print("The accuracy of Gaussian Naïve Bayes is:",correct_rate)
The accuracy of Gaussian Naïve Bayes is: 0.4666666666666667

III. Gaussian SVM

In [74]:
from sklearn import svm
from sklearn.preprocessing import StandardScaler
In [75]:
scaler = StandardScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
In [76]:
svm_classifier = svm.SVC(kernel ='linear')
svm_classifier.fit(X_train,Y_train)

prediction = svm_classifier.predict(X_test)
correct_rate = svm_classifier.score(X_test, Y_test)
In [77]:
print("The accuracy of Gaussian SVM is:",correct_rate)
The accuracy of Gaussian SVM is: 0.5148148148148148

Conclusion

After 100 epochs, MLP provided 38.38% accuracy. Logistic regression provided 44.81%, Gaussian Naïve Bayes provided 46.67%, and Gaussian SVM provided 51.48% accuracy. It is clear to see that Gaussian SVM provided the highest accuracy and MLP neural network provided the least accuracy. However, it does not mean that neural network is not as reliable as classifiers.

There are two possible reasons why MLP did not perform better than classifiers. First of all, the data is linear. This is quilt important. ANN have the ability to learn and model non-linear and complex relationship and neural networks are good to model with nonlinear data with large number of inputs. On the other hand, regression models and classifiers would perform better on linear data since these models are based on statistical model.

Secondly, the data set is not high-dimensional. Neural networks are best for situations where the data is high-dimensional like images. The data set used is a regular linear data set. So, the shape of the data set is perfectly fit the regression models and classifiers.

In conclusion, back to the research question, two conditions would be the best for neural network. Neural network would perform the best for the high dimensional data or non-linear data.

In [ ]: